# Reksio - Memory Map Editor
# Copyright (C) 2023 CERN
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
# In applying this licence, CERN does not waive the privileges and immunities
# granted to it by virtue of its status as an Intergovernmental Organization or
# submit itself to any jurisdiction.

import reksio
import expand_submap

import PyCheby.utils as utils


def node_crawler(node):
    """
    Yield *node* itself when it has no parent, then each direct child.

    Only one level of children is visited; grandchildren are not.
    """
    if node.parent is None:
        yield node
    yield from node.children()


def SI_fixer(value):
    """
    Convert a size/count string to an int.

    Accepts any integer literal understood by ``int(value, 0)`` (decimal,
    ``0x``, ``0o``, ``0b``), or a decimal number followed by a binary
    multiplier suffix: ``k`` = 2**10, ``M`` = 2**20, ``G`` = 2**30
    (e.g. ``"4k"`` -> 4096). Surrounding whitespace is ignored.

    Raises ValueError if the string is empty or not a valid number.
    """
    si_prefixes = {
        'k': 2 ** 10,
        'M': 2 ** 20,
        'G': 2 ** 30
    }
    text = value.strip()
    if not text:
        # Indexing text[-1] on an empty string would raise a confusing IndexError.
        raise ValueError(f"Empty numeric value: {value!r}")
    multiplier = si_prefixes.get(text[-1])
    number = text[:-1]
    if multiplier is not None and number.isdecimal():
        return int(number) * multiplier
    return int(text, 0)


def on_load(main_window, node):
    """
    Plugin hook taking *main_window* and a *node*; intentionally a no-op.

    Presumably invoked by Reksio while a node is being loaded — TODO confirm.
    """
    return None


def on_loaded(main_window):
    """
    Plugin hook: expand submaps, then compute addresses for *main_window*.

    Expansion must come first because submap_addresser rejects a submap that
    has a ``filename`` but no children yet ("Please expand submaps ...").
    """
    expand_submap.expand(main_window)
    # Addressing runs on the already-expanded tree.
    addresser_action(main_window)


# Separator character for node names. Not referenced in the code shown here;
# presumably used by other plugin code — TODO confirm.
NODE_SEP = "_"


def onInsertNode(main_window, parent_index, first, last):
    """
    Re-address the whole memory map after nodes are inserted.

    *parent_index*, *first* and *last* identify the inserted rows (presumably
    Qt ``rowsInserted`` arguments) and are not used; everything is recomputed.
    """
    addresser_action(main_window)


def onRemoveNode(main_window, parent_index, first, last):
    """
    Re-address the whole memory map after nodes are removed.

    *parent_index*, *first* and *last* identify the removed rows (presumably
    Qt ``rowsRemoved`` arguments) and are not used; everything is recomputed.
    """
    addresser_action(main_window)


def onMoveNode(main_window, parent_index, start, end, destination_index, row):
    """
    Re-address the whole memory map after nodes are moved.

    The source/destination arguments (presumably Qt ``rowsMoved`` arguments)
    are not used; everything is recomputed.
    """
    addresser_action(main_window)


def menu_trigger_addresser_action(main_window):
    """
    Menu handler: recompute addresses, then report completion via ``reksio.info``.

    addresser_action reports failures itself with ``reksio.warn``; this
    completion message is emitted either way.
    """
    addresser_action(main_window)
    # Plain string: the message has no placeholders, so no f-string is needed.
    reksio.info("Manually triggered addressing done!")


def addresser_action(main_window):
    """
    Recompute addresses for the first top-level node of *main_window*'s model.

    The first child of the model root is used as the root of an
    AddressCalculator. Any exception raised while addressing is reported with
    ``reksio.warn`` instead of propagating, because this runs from UI
    callbacks (node insert/remove/move, load, menu). When the model root has
    no children there is nothing to address and the function returns quietly.
    """
    nodes_model = main_window.getNodesModel()
    root_node = nodes_model.getRoot()
    try:
        top_level = root_node.children()
        if not top_level:
            # An empty model (e.g. the last node was just removed) is not an
            # error; don't warn with "list index out of range".
            return
        calc = AddressCalculator(top_level[0])
        calc.execute()
    except Exception as e:
        reksio.warn(f"Addressing failed: {e}")


def set_user_size(node):
    """
    Apply a user-specified ``size`` attribute to *node*'s computed size.

    Does nothing when *node* has no ``size`` attribute. Otherwise the user
    size (parsed with SI_fixer) replaces ``computed.size``, provided it is not
    smaller than the size already computed.

    Raises Exception if the user size is smaller than the computed size.
    """
    size_attr = getattr(node, 'size', None)
    if size_attr is None:
        return

    requested = SI_fixer(size_attr.value)
    computed_attr = node.getAttribute(["computed", "size"])
    required = int(computed_attr.value)
    if requested < required:
        raise Exception(f"User defined size ({requested}) is too small for {node}. Required size: {required}")
    computed_attr.value = str(requested)


def get_address_space(node):
    """
    Return the value of a submap node's ``address-space`` attribute.

    Returns None for nodes that are not submaps, and for submaps without an
    ``address-space`` attribute.
    """
    if node.type != "submap":
        return None
    space_attr = getattr(node, "address-space", None)
    return None if space_attr is None else space_attr.value


def register_addresser(node, address_calculator):
    """
    Compute ``computed.size`` and ``computed.align`` for a register node.

    The size is ``width // 8`` bytes, with ``width`` parsed by ``int(..., 0)``.
    When ``address_calculator.align_regs`` is set, the alignment is
    ``utils.align(size, word_size)``; otherwise it is the calculator's
    ``word_size``.

    Raises Exception if the ``width`` attribute is missing or unparsable.
    """
    try:
        width = int(node["width"].value, 0)
    except Exception as err:
        # Narrower than a bare except (KeyboardInterrupt/SystemExit pass
        # through), and the underlying cause is preserved for debugging.
        raise Exception(f"width attribute not set in {node}") from err

    size_bytes = width // 8
    node.getAttribute(["computed", "size"]).value = str(size_bytes)
    if address_calculator.align_regs:
        alignment = utils.align(size_bytes, address_calculator.word_size)
    else:
        alignment = address_calculator.word_size
    node.getAttribute(["computed", "align"]).value = str(alignment)


def block_addresser(node, address_calculator):
    """
    Compute size and alignment of a block node.

    A block without children takes its ``computed.size`` from its ``size``
    attribute (parsed with SI_fixer). A block with children is laid out by
    composite_addresser, and then a user ``size`` is applied by
    set_user_size. In both cases align_block runs last.
    """
    if not node.children():
        node.getAttribute(["computed", "size"]).value = str(SI_fixer(node.size.value))
    else:
        composite_addresser(node, address_calculator)
        set_user_size(node)

    align_block(node)


def array_addresser(node, address_calculator):
    """
    Compute size and alignment of an array node.

    The children (laid out by composite_addresser, then adjusted by
    set_user_size) form one element. Its size is aligned with
    ``utils.align`` to the computed alignment, then repeated ``repeat`` times
    (parsed with SI_fixer).

    Alignment is on when the ``align`` attribute is absent, empty, or "true"
    (case-insensitive). In that case the element size is rounded with
    ``utils.round_pow2``, and the alignment becomes element size times
    ``utils.round_pow2(repeat)``. Otherwise the size is element size times
    ``repeat`` and the alignment from composite_addresser is kept.
    """
    composite_addresser(node, address_calculator)
    set_user_size(node)
    element_size = utils.align(int(node.getAttribute(["computed", "size"]).value),
                               int(node.getAttribute(["computed", "align"]).value))

    # ``align`` is an attribute object whose text lives in ``.value``. A plain
    # truthiness test would treat align="false" as enabled, so normalise the
    # value the same way align_block does.
    align_attr = getattr(node, "align", None)
    align_value = getattr(align_attr, "value", align_attr)
    node_align = align_value is None or str(align_value).strip().lower() in ("", "true")

    repeat = SI_fixer(node.repeat.value)
    if node_align:
        element_size = utils.round_pow2(element_size)
        node.getAttribute(["computed", "size"]).value = str(element_size * repeat)
        node.getAttribute(["computed", "align"]).value = str(element_size * utils.round_pow2(repeat))
    else:
        node.getAttribute(["computed", "size"]).value = str(element_size * repeat)


def memory_addresser(node, address_calculator):
    """
    Compute size and alignment of a memory node.

    The children (laid out by composite_addresser) describe one memory word.
    ``memsize`` (parsed with SI_fixer) must be a multiple of that word's size
    rounded with ``utils.round_pow2``; the quotient is the depth. The word is
    then aligned with ``utils.align`` to the computed alignment and rounded
    with ``utils.round_pow2``. The memory size is depth times that stride,
    and the alignment is ``utils.round_pow2`` of the memory size. A user
    ``size`` (set_user_size) and align_block are applied last.

    Raises Exception if ``memsize`` is not a multiple of the word size.
    """
    def computed(key):
        # Fresh handle on one of this node's computed attributes.
        return node.getAttribute(["computed", key])

    composite_addresser(node, address_calculator)
    memsize = SI_fixer(node.memsize.value)
    word_bytes = utils.round_pow2(int(computed("size").value))
    if memsize % word_bytes != 0:
        raise Exception(f"{node} memsize {memsize} is not a multiple of register's size")
    depth = memsize // word_bytes

    stride = utils.round_pow2(utils.align(int(computed("size").value),
                                          int(computed("align").value)))
    computed("size").value = str(depth * stride)
    computed("align").value = str(utils.round_pow2(int(computed("size").value)))

    set_user_size(node)
    align_block(node)


def address_space_addresser(node, address_calculator):
    """
    Compute size and alignment of an address-space node.

    The steps are: lay out the children with composite_addresser, apply a
    user ``size`` with set_user_size, then apply align_block.
    """
    composite_addresser(node, address_calculator)
    set_user_size(node)
    align_block(node)

def repeat_addresser(node, address_calculator):
    """
    Compute size and alignment of a repeat node.

    A repeat must have exactly one child. After composite_addresser, the
    child layout's size is aligned with ``utils.align`` to the computed
    alignment. The total size is that element size times ``count``. Then
    align_block is applied.

    Raises Exception if the node does not have exactly one child.
    """
    if len(node.children()) != 1:
        raise Exception(f"{node} must have a single child")

    composite_addresser(node, address_calculator)
    # Parse like the other numeric attributes (array ``repeat``, ``memsize``,
    # ``size``), so hex literals and k/M/G suffixes are accepted here too.
    count = SI_fixer(node.count.value)

    element_size = utils.align(int(node.getAttribute(["computed", "size"]).value),
                               int(node.getAttribute(["computed", "align"]).value))

    node.getAttribute(["computed", "size"]).value = str(count * element_size)
    align_block(node)


def submap_addresser(node, address_calculator):
    """
    Compute size and alignment of a submap node.

    A submap with children is laid out by composite_addresser, and a user
    ``size`` is applied. A submap with a ``filename`` but no children has not
    been expanded, which is an error. A submap without a ``filename`` takes
    its ``computed.size`` from its ``size`` attribute (parsed with SI_fixer).
    align_block runs last.

    Raises Exception if the submap references a file but was not expanded.
    """
    references_file = hasattr(node, "filename")
    children = node.children()

    if children:
        composite_addresser(node, address_calculator)
        set_user_size(node)
    elif references_file:
        raise Exception(f"Please expand submaps ({node}) before addressing!")

    if not references_file:
        node.getAttribute(["computed", "size"]).value = str(SI_fixer(node.size.value))

    align_block(node)


def root_addresser(node, address_calculator):
    """
    Compute the layout of a memory-map node.

    The node's computed address and offset are both 0. Its children are laid
    out with a new AddressCalculator rooted at this node; the supplied
    *address_calculator* is not used. Register alignment (``align_regs``) is
    enabled unless the ``bus`` name starts with ``cern-be-vme-``. A user
    ``size`` is applied last.
    """
    for key in ("address", "offset_address"):
        node.getAttribute(["computed", key]).value = str(0)

    bus = node["bus"].value
    map_calculator = AddressCalculator(node)
    is_vme = bus.startswith('cern-be-vme-')
    if not is_vme:
        map_calculator.align_regs = True

    composite_addresser(node, map_calculator)
    set_user_size(node)


def composite_addresser(node, address_calculator):
    """
    Lay out the addressable children of *node*.

    A calculator duplicated from *address_calculator* is used. First every
    addressable child is sized and aligned recursively (via addresser), and
    the node's ``computed.align`` becomes the largest child alignment (0 if
    there are none). Then the children are placed in order with
    ``compute_address``, and the node's ``computed.size`` becomes the
    furthest child end (offset + size).
    """
    addressable = ("reg", "block", "array", "memory", "repeat", "submap", "memory-map", "address-space")
    layout = address_calculator.duplicate()
    members = [child for child in node.children() if child.type in addressable]

    # Pass 1: size/align each member; the container takes the largest alignment.
    largest_align = 0
    for member in members:
        addresser(member, layout)
        largest_align = max(largest_align, int(member.getAttribute(["computed", "align"]).value))
    node.getAttribute(["computed", "align"]).value = str(largest_align)

    # Pass 2: place members in order; the container spans to the furthest end.
    extent = 0
    for member in members:
        layout.compute_address(member)
        member_end = (int(member.getAttribute(["computed", "offset_address"]).value)
                      + int(member.getAttribute(["computed", "size"]).value))
        extent = max(extent, member_end)
    node.getAttribute(["computed", "size"]).value = str(extent)


def align_block(node):
    """
    Round *node*'s computed size and alignment with ``utils.round_pow2``.

    Alignment applies when the node's ``align`` attribute is absent, empty,
    or "true" (case-insensitive). ``computed.size`` becomes
    ``round_pow2(size)``, and ``computed.align`` becomes ``round_pow2`` of
    that new size. With any other ``align`` value the node is left unchanged.
    """
    # ``align`` is an attribute object whose text lives in ``.value``;
    # comparing the object itself to "true" would ignore an explicit "true".
    align_attr = getattr(node, "align", None)
    align_value = getattr(align_attr, "value", align_attr)
    if align_value is not None and str(align_value).strip().lower() not in ("", "true"):
        return

    rounded_size = utils.round_pow2(int(node.getAttribute(["computed", "size"]).value))
    node.getAttribute(["computed", "size"]).value = str(rounded_size)
    node.getAttribute(["computed", "align"]).value = str(utils.round_pow2(rounded_size))


def set_abs_address(node, base_address):
    """
    Recursively assign absolute ``computed.address`` values, starting at *node*.

    A node's address is *base_address* plus its ``computed.offset_address``,
    except for address-space nodes, whose address is always 0. Children of
    array/memory/repeat nodes are addressed from base 0. Children of
    memory-map/block/submap/address-space nodes are addressed from the
    parent's address, except nested address-space children, which use base 0.
    The main window is notified (``dataChanged``) of each written address.
    Nodes of other types are skipped.
    """
    if node.type not in ["reg", "block", "array", "submap", "memory-map", "memory", "repeat", "address-space"]:
        return

    offset = int(node.getAttribute(["computed", "offset_address"]).value)
    if node.type == 'address-space':
        # From Reksio v2.0.0 the base address of the Address-Space node is hardcoded to 0x0 in order
        # to follow the "cheby" tool convention.
        address = 0
    else:
        address = base_address + offset

    # Store the final value before notifying the view, so an address-space
    # node is never reported with base + offset.
    node.getAttribute(["computed", "address"]).value = str(address)
    reksio.get_main_window().dataChanged(node.getAttribute(["computed", "address"]))

    if node.type in ["array", "memory", "repeat"]:
        for child in node.children():
            set_abs_address(child, 0)
    elif node.type in ["memory-map", "block", "submap", "address-space"]:
        for child in node.children():
            set_abs_address(child, 0 if child.type == 'address-space' else address)
    # "reg" nodes have no addressable children.


def add_attrib_if_not_exists(attr_container, attr_name):
    """
    Ensure *attr_container* holds an attribute named *attr_name*.

    A missing attribute is created as ``reksio.Attribute(attr_name)`` with
    ``savable = False`` (presumably so computed values are not written to
    the saved map — TODO confirm). In either case the main window is
    notified of the attribute via ``dataChanged``.
    """
    existing = attr_container.getAttribute(attr_name)
    if existing is None:
        created = reksio.Attribute(attr_name)
        created.savable = False
        attr_container.addAttribute(created)
        # Re-fetch so we notify with the container's own handle.
        existing = attr_container.getAttribute(attr_name)

    reksio.get_main_window().dataChanged(existing)


def add_attrib_container_if_not_exists(node, container_name):
    """
    Ensure *node*'s root attribute container has a sub-container *container_name*.

    An existing sub-container is left untouched.
    """
    containers = node.attribute_container()
    if containers.getAttributeContainer(container_name) is not None:
        return
    containers.addAttributeContainer(container_name)


def addresser(node, address_calculator):
    """
    Compute the layout of *node* with the type-specific addresser.

    For addressable node types, the ``computed`` attribute container and its
    ``address``, ``size``, ``align`` and ``offset_address`` attributes are
    created first if missing. Then the handler for the node's type runs.
    Nodes of any other type are ignored.
    """
    handlers = {
        "reg": register_addresser,
        "block": block_addresser,
        "array": array_addresser,
        "memory": memory_addresser,
        "repeat": repeat_addresser,
        "address-space": address_space_addresser,
        "submap": submap_addresser,
        "memory-map": root_addresser,
    }
    handler = handlers.get(node.type)
    if handler is None:
        return

    add_attrib_container_if_not_exists(node, "computed")
    computed_attrs = node.attribute_container().getAttributeContainer("computed")
    for attr_name in ("address", "size", "align", "offset_address"):
        add_attrib_if_not_exists(computed_attrs, attr_name)

    handler(node, address_calculator)


class AddressCalculator:
    """
    This class does addressing for a memory map.

    Instance attributes (set in ``__init__``):
        root       -- memory-map node whose ``bus`` attribute selects the word size.
        align_regs -- when True, register_addresser aligns a register to
                      ``utils.align(size, word_size)`` instead of ``word_size``.
        word_size  -- bus word size in bytes. This int instance attribute
                      shadows the ``word_size`` method once ``__init__`` ran.
        address    -- next free offset used by ``compute_address``.
    """
    # Not referenced within this class; kept for compatibility.
    addressers = {}

    # Word size in bytes for buses with a fixed width.
    _BUS_WORD_SIZES = {
        "wb-32-be": 4,
        "wb-16-be": 2,
        "axi4-lite-32": 4,
    }

    def word_size(self):
        """
        Return the bus word size in bytes, derived from ``self.root["bus"]``.

        ``cern-be-vme-*`` buses take their width in bits from the last
        dash-separated field of the bus name (e.g. ``cern-be-vme-32`` -> 4).

        Raises Exception if the bus attribute is missing, empty or unknown.
        """
        try:
            bus = self.root["bus"].value
        except Exception as err:
            # Narrower than a bare except; keep the original cause attached.
            raise Exception("bus attribute in memory-map is not specified") from err
        if not bus:
            # Present but unset: report it as unspecified rather than failing
            # later with an AttributeError on None.startswith.
            raise Exception("bus attribute in memory-map is not specified")
        if bus in self._BUS_WORD_SIZES:
            return self._BUS_WORD_SIZES[bus]
        if bus.startswith("cern-be-vme-"):
            return int(bus.split('-')[-1]) // 8
        raise Exception(f"Unknown bus: {bus}")

    def __init__(self, root, address_space=None):
        # address_space is accepted for backward compatibility but not used.
        self.root = root
        self.align_regs = False
        # Replace the word_size method on this instance with its int result.
        self.word_size = self.word_size()
        self.address = 0

    def duplicate(self):
        """Return a new calculator for the same root and ``align_regs``, starting at offset 0."""
        calculator = AddressCalculator(self.root)
        calculator.align_regs = self.align_regs
        return calculator

    def compute_address(self, node):
        """
        Place *node* at the next offset and advance past it.

        Without an ``address`` attribute (or with ``'next'``), the current
        offset is aligned with ``utils.align`` to the node's computed
        alignment. Otherwise the explicit address, parsed by ``int(..., 0)``,
        is used. A warning is emitted when the new address lies before the
        current offset. The result is stored in ``computed.offset_address``,
        and the offset advances by the node's computed size.
        """
        node_size = int(node.getAttribute(["computed", "size"]).value)

        node_addr = getattr(node, 'address', None)
        if node_addr is None or node_addr.value == 'next':
            new_address = utils.align(self.address, int(node.getAttribute(["computed", "align"]).value))
        else:
            new_address = int(node_addr.value, 0)

        if new_address < self.address:
            reksio.warn(f"The address of node {node} is overlapping with other one!")

        self.address = new_address
        node.getAttribute(["computed", "offset_address"]).value = str(self.address)
        self.address += node_size

    def addresser(self, node):
        """Compute *node*'s layout with the module-level addresser using this calculator."""
        return addresser(node, self)

    def execute(self):
        """
        Does the addressing for a memory map.

        The root is laid out first, then absolute addresses are assigned
        starting from base address 0.
        """
        root_node = self.root
        self.addresser(root_node)
        set_abs_address(root_node, 0)
